

import torch

from src.models.model_utils import choose_model


# ---- smoke tests: check that each model builds and runs inference successfully ----

def test_lenet_mnist_reference():
    '''
    Smoke test: the baseline (full reference) "lenet5" model builds and runs
    a forward pass on a batch of two MNIST-shaped (1x28x28) inputs.

    Only inference is exercised; no training step is performed. The test
    fails if model construction or the forward pass raises, or if any output
    entry is NaN or +/-Inf.
    '''
    model = choose_model("lenet5", baseline=True, tucker=False, mat_dlrt=False, adaptive=False, tau=0.1, device="cpu",
                         chain_init=False)
    test_input = torch.ones([2, 1, 28, 28], dtype=torch.float32)

    # --- test inference (no autograd graph needed, so no .detach() either) ---
    with torch.no_grad():
        test_output = model(test_input)

    # isfinite rejects NaN as well as +/-Inf.
    assert torch.isfinite(test_output).all(), "lenet5 reference inference produced non-finite outputs"


def test_lenet_mnist_tucker_adaptive():
    '''
    Smoke test: the "lenet5" model in Tucker DLRT mode with rank adaptation
    enabled (adaptive=True, tau=0.1) builds and runs a forward pass on a
    batch of two MNIST-shaped (1x28x28) inputs.

    Only inference is exercised; no training step is performed. The test
    fails if model construction or the forward pass raises, or if any output
    entry is NaN or +/-Inf.
    '''
    model = choose_model("lenet5", baseline=False, tucker=True, mat_dlrt=False, adaptive=True, tau=0.1, device="cpu",
                         chain_init=False)
    test_input = torch.ones([2, 1, 28, 28], dtype=torch.float32)

    # --- test inference (no autograd graph needed, so no .detach() either) ---
    with torch.no_grad():
        test_output = model(test_input)

    # isfinite rejects NaN as well as +/-Inf.
    assert torch.isfinite(test_output).all(), "lenet5 adaptive Tucker inference produced non-finite outputs"


def test_lenet_mnist_tucker_fr():
    '''
    Smoke test: the "lenet5" model in fixed-rank Tucker DLRT mode
    (adaptive=False) builds and runs a forward pass on a batch of two
    MNIST-shaped (1x28x28) inputs.

    Only inference is exercised; no training step is performed. The test
    fails if model construction or the forward pass raises, or if any output
    entry is NaN or +/-Inf.
    '''
    model = choose_model("lenet5", baseline=False, tucker=True, mat_dlrt=False, adaptive=False, tau=0.1, device="cpu",
                         chain_init=False)
    test_input = torch.ones([2, 1, 28, 28], dtype=torch.float32)

    # --- test inference (no autograd graph needed, so no .detach() either) ---
    with torch.no_grad():
        test_output = model(test_input)

    # isfinite rejects NaN as well as +/-Inf.
    assert torch.isfinite(test_output).all(), "lenet5 fixed-rank Tucker inference produced non-finite outputs"


def test_Lenet5_cifar10_reference():
    '''
    Smoke test: the baseline (full reference) "lenet5_rgb" model builds and
    runs a forward pass on a batch of two CIFAR-10-shaped (3x32x32) inputs.

    Only inference is exercised; no training step is performed. The test
    fails if model construction or the forward pass raises, or if any output
    entry is NaN or +/-Inf.
    '''
    model = choose_model("lenet5_rgb", baseline=True, tucker=False, mat_dlrt=False, adaptive=False, tau=0.1,
                         device="cpu", chain_init=False)
    test_input = torch.ones([2, 3, 32, 32], dtype=torch.float32)

    # --- test inference (no autograd graph needed, so no .detach() either) ---
    with torch.no_grad():
        test_output = model(test_input)

    # isfinite rejects NaN as well as +/-Inf.
    assert torch.isfinite(test_output).all(), "lenet5_rgb reference inference produced non-finite outputs"


def test_Lenet5_cifar10_tucker_adaptive():
    '''
    Smoke test: the "lenet5_rgb" model in Tucker DLRT mode with rank
    adaptation enabled (adaptive=True, tau=0.1) builds and runs a forward
    pass on a batch of two CIFAR-10-shaped (3x32x32) inputs.

    Only inference is exercised; no training step is performed. The test
    fails if model construction or the forward pass raises, or if any output
    entry is NaN or +/-Inf.
    '''
    model = choose_model("lenet5_rgb", baseline=False, tucker=True, mat_dlrt=False, adaptive=True, tau=0.1,
                         device="cpu", chain_init=False)
    test_input = torch.ones([2, 3, 32, 32], dtype=torch.float32)

    # --- test inference (no autograd graph needed, so no .detach() either) ---
    with torch.no_grad():
        test_output = model(test_input)

    # isfinite rejects NaN as well as +/-Inf.
    assert torch.isfinite(test_output).all(), "lenet5_rgb adaptive Tucker inference produced non-finite outputs"


def test_Lenet5_cifar10_tucker_fr():
    '''
    Smoke test: the "lenet5_rgb" model in fixed-rank Tucker DLRT mode
    (adaptive=False) builds and runs a forward pass on a batch of two
    CIFAR-10-shaped (3x32x32) inputs.

    Only inference is exercised; no training step is performed. The test
    fails if model construction or the forward pass raises, or if any output
    entry is NaN or +/-Inf.
    '''
    model = choose_model("lenet5_rgb", baseline=False, tucker=True, mat_dlrt=False, adaptive=False, tau=0.1,
                         device="cpu", chain_init=False)
    test_input = torch.ones([2, 3, 32, 32], dtype=torch.float32)

    # --- test inference (no autograd graph needed, so no .detach() either) ---
    with torch.no_grad():
        test_output = model(test_input)

    # isfinite rejects NaN as well as +/-Inf.
    assert torch.isfinite(test_output).all(), "lenet5_rgb fixed-rank Tucker inference produced non-finite outputs"


def test_vgg16_cifar10_reference():
    '''
    Smoke test: the baseline (full reference) "vgg16" model builds and runs
    a forward pass on a batch of two CIFAR-10-shaped (3x32x32) inputs.

    Only inference is exercised; no training step is performed. The test
    fails if model construction or the forward pass raises, or if any output
    entry is NaN or +/-Inf.
    '''
    model = choose_model("vgg16", baseline=True, tucker=False, mat_dlrt=False, adaptive=False, tau=0.1, device="cpu",
                         chain_init=False)
    test_input = torch.ones([2, 3, 32, 32], dtype=torch.float32)

    # --- test inference (no autograd graph needed, so no .detach() either) ---
    with torch.no_grad():
        test_output = model(test_input)

    # isfinite rejects NaN as well as +/-Inf.
    assert torch.isfinite(test_output).all(), "vgg16 reference inference produced non-finite outputs"


def test_vgg16_cifar10_tucker_adaptive():
    '''
    Smoke test: the "vgg16" model in Tucker DLRT mode with rank adaptation
    enabled (adaptive=True, tau=0.1) builds and runs a forward pass on a
    batch of two CIFAR-10-shaped (3x32x32) inputs.

    Only inference is exercised; no training step is performed. The test
    fails if model construction or the forward pass raises, or if any output
    entry is NaN or +/-Inf.
    '''
    model = choose_model("vgg16", baseline=False, tucker=True, mat_dlrt=False, adaptive=True, tau=0.1, device="cpu",
                         chain_init=False)
    test_input = torch.ones([2, 3, 32, 32], dtype=torch.float32)

    # --- test inference (no autograd graph needed, so no .detach() either) ---
    with torch.no_grad():
        test_output = model(test_input)

    # isfinite rejects NaN as well as +/-Inf.
    assert torch.isfinite(test_output).all(), "vgg16 adaptive Tucker inference produced non-finite outputs"


def test_vgg16_cifar10_tucker_fr():
    '''
    Smoke test: the "vgg16" model in fixed-rank Tucker DLRT mode
    (adaptive=False) builds and runs a forward pass on a batch of two
    CIFAR-10-shaped (3x32x32) inputs.

    Only inference is exercised; no training step is performed. The test
    fails if model construction or the forward pass raises, or if any output
    entry is NaN or +/-Inf.
    '''
    model = choose_model("vgg16", baseline=False, tucker=True, mat_dlrt=False, adaptive=False, tau=0.1, device="cpu",
                         chain_init=False)
    test_input = torch.ones([2, 3, 32, 32], dtype=torch.float32)

    # --- test inference (no autograd graph needed, so no .detach() either) ---
    with torch.no_grad():
        test_output = model(test_input)

    # isfinite rejects NaN as well as +/-Inf.
    assert torch.isfinite(test_output).all(), "vgg16 fixed-rank Tucker inference produced non-finite outputs"


def test_alexnet_cifar10_reference():
    '''
    Smoke test: the baseline (full reference) "alexnet" model builds and runs
    a forward pass on a batch of two CIFAR-10-shaped (3x32x32) inputs.

    Only inference is exercised; no training step is performed. The test
    fails if model construction or the forward pass raises, or if any output
    entry is NaN or +/-Inf.
    '''
    model = choose_model("alexnet", baseline=True, tucker=False, mat_dlrt=False, adaptive=False, tau=0.1, device="cpu",
                         chain_init=False)
    test_input = torch.ones([2, 3, 32, 32], dtype=torch.float32)

    # --- test inference (no autograd graph needed, so no .detach() either) ---
    with torch.no_grad():
        test_output = model(test_input)

    # isfinite rejects NaN as well as +/-Inf.
    assert torch.isfinite(test_output).all(), "alexnet reference inference produced non-finite outputs"


def test_alexnet_cifar10_tucker_adaptive():
    '''
    Smoke test: the "alexnet" model in Tucker DLRT mode with rank adaptation
    enabled (adaptive=True, tau=0.1) builds and runs a forward pass on a
    batch of two CIFAR-10-shaped (3x32x32) inputs.

    Only inference is exercised; no training step is performed. The test
    fails if model construction or the forward pass raises, or if any output
    entry is NaN or +/-Inf.
    '''
    model = choose_model("alexnet", baseline=False, tucker=True, mat_dlrt=False, adaptive=True, tau=0.1, device="cpu",
                         chain_init=False)
    test_input = torch.ones([2, 3, 32, 32], dtype=torch.float32)

    # --- test inference (no autograd graph needed, so no .detach() either) ---
    with torch.no_grad():
        test_output = model(test_input)

    # isfinite rejects NaN as well as +/-Inf.
    assert torch.isfinite(test_output).all(), "alexnet adaptive Tucker inference produced non-finite outputs"


def test_alexnet_cifar10_tucker_fr():
    '''
    Smoke test: the "alexnet" model in fixed-rank Tucker DLRT mode
    (adaptive=False) builds and runs a forward pass on a batch of two
    CIFAR-10-shaped (3x32x32) inputs.

    Only inference is exercised; no training step is performed. The test
    fails if model construction or the forward pass raises, or if any output
    entry is NaN or +/-Inf.
    '''
    model = choose_model("alexnet", baseline=False, tucker=True, mat_dlrt=False, adaptive=False, tau=0.1, device="cpu",
                         chain_init=False)
    test_input = torch.ones([2, 3, 32, 32], dtype=torch.float32)

    # --- test inference (no autograd graph needed, so no .detach() either) ---
    with torch.no_grad():
        test_output = model(test_input)

    # isfinite rejects NaN as well as +/-Inf.
    assert torch.isfinite(test_output).all(), "alexnet fixed-rank Tucker inference produced non-finite outputs"
